Conversation
- Add SplashAttentionConfig for configuring splash attention parameters
- Integrate splash attention into Transformer, TransformerEncoder, and VisionTransformerBase
- Update CLIP, SigLIP, and VisionTransformer models to accept splash_attention_config
- Add tokamax>=0.0.9 dependency and update Python requirement to >=3.11
- Add tests for splash attention configuration
- Update README with usage documentation
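As a rough illustration of the configuration pattern the bullets above describe, here is a minimal, framework-free sketch. The exact fields of `SplashAttentionConfig` are not shown in this thread; only `enabled` and `block_kv` appear later, so the other fields here are assumptions.

```python
from dataclasses import dataclass


# Hypothetical sketch of a splash-attention config dataclass.
# Only `enabled` and `block_kv` are mentioned in this PR thread;
# `block_q` is an assumed, illustrative field.
@dataclass
class SplashAttentionConfig:
    enabled: bool = False
    block_q: int = 256   # assumed query block size
    block_kv: int = 256  # key/value block size


# Models that accept a splash_attention_config would be passed an instance:
config = SplashAttentionConfig(enabled=True, block_kv=256)
print(config.enabled, config.block_kv)
```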
…ention and regular attention models, updated README
/gemini
Summary of Changes: Hello @pythoncrazy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
/gemini review
Code Review
This pull request introduces experimental support for TPU-optimized Splash Attention via the tokamax library. The changes are well-structured, adding a new configuration class and integrating it throughout the model hierarchy. The addition of documentation and tests is also a great practice. My review focuses on improving the consistency of the new create_splash_attention_fn API and simplifying some of the integration logic for better maintainability.
) -> Callable[..., Float[Array, "batch heads seq head_dim"]]:
    """Create a splash attention function compatible with nnx.MultiHeadAttention.

    Args:
        config (SplashAttentionConfig): Splash attention configuration.
        num_heads (int): Number of attention heads.
        head_dim (int): Dimension of each attention head.

    Returns:
        Callable: An attention function. Returns splash attention if enabled and available,
            otherwise returns the default dot_product_attention.
    """
    if not _TOKAMAX_AVAILABLE or not config.enabled:
        from flax.nnx.nn.attention import dot_product_attention

        return dot_product_attention
The function create_splash_attention_fn currently returns dot_product_attention when Splash Attention is disabled or unavailable. This is inconsistent with the new test test_create_fn_returns_none_when_disabled in tests/test_splash_attention.py, which expects None. Returning None would make the function's behavior more explicit, align with the test's expectation, and simplify the logic at call sites.
I suggest changing the implementation to return None when splash attention is not used. You'll also need to update the function's return type hint and docstring. This change will make the test test_create_fn_returns_none_when_disabled pass as written.
) -> Callable[..., Float[Array, "batch heads seq head_dim"]]:
    """Create a splash attention function compatible with nnx.MultiHeadAttention.

    Args:
        config (SplashAttentionConfig): Splash attention configuration.
        num_heads (int): Number of attention heads.
        head_dim (int): Dimension of each attention head.

    Returns:
        Callable: An attention function. Returns splash attention if enabled and available,
            otherwise returns the default dot_product_attention.
    """
    if not _TOKAMAX_AVAILABLE or not config.enabled:
        from flax.nnx.nn.attention import dot_product_attention
        return dot_product_attention

) -> Callable[..., Float[Array, "batch heads seq head_dim"]] | None:
    """Create a splash attention function compatible with nnx.MultiHeadAttention.

    Args:
        config (SplashAttentionConfig): Splash attention configuration.
        num_heads (int): Number of attention heads.
        head_dim (int): Dimension of each attention head.

    Returns:
        Callable | None: An attention function. Returns splash attention if enabled and available,
            otherwise returns None.
    """
    if not _TOKAMAX_AVAILABLE or not config.enabled:
        return None
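The pattern the reviewer proposes — a factory that returns `None` when the optimized path is unavailable, leaving the caller to fall back to a default — can be sketched without the Flax machinery. All names below are illustrative stand-ins, not the PR's actual API:

```python
from typing import Callable, Optional


def default_attention(q, k, v):
    # Stand-in for the framework's default dot-product attention.
    return q


def create_fast_attention_fn(enabled: bool) -> Optional[Callable]:
    # Return None when the optimized kernel is disabled or unavailable,
    # so the call site decides the fallback explicitly.
    if not enabled:
        return None
    return lambda q, k, v: v  # stand-in for the optimized kernel


# Call site: fall back to the default when the factory declines.
attention_fn = create_fast_attention_fn(enabled=False) or default_attention
print(attention_fn is default_attention)
```

This keeps one source of truth for the fallback (the call site) instead of baking it into the factory.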
src/jimm/common/vit.py
attention_fn = None
if splash_attention_config is not None:
    attention_fn = create_splash_attention_fn(
        splash_attention_config,
        num_heads=num_heads,
        head_dim=hidden_size // num_heads,
    )

attn_kwargs: dict = {
    "num_heads": num_heads,
    "in_features": hidden_size,
    "broadcast_dropout": False,
    "decode": False,
    "deterministic": False,
    "dtype": dtype,
    "param_dtype": param_dtype,
    "rngs": rngs,
    "kernel_init": nnx.with_partitioning(nnx.initializers.xavier_uniform(), mesh_rules("map_attn_in", "map_attn_out")),
    "bias_init": nnx.with_partitioning(
        nnx.initializers.zeros_init(),
        mesh_rules(
            "map_attn_out",
        ),
    ),
}
if attention_fn is not None:
    attn_kwargs["attention_fn"] = attention_fn
self.attn = nnx.MultiHeadAttention(**attn_kwargs)
The initialization of self.attn can be simplified. Instead of creating a dictionary attn_kwargs and conditionally adding attention_fn, you can determine attention_fn first and then pass it directly to the nnx.MultiHeadAttention constructor. Since nnx.MultiHeadAttention handles attention_fn=None by using the default attention, the logic can be made more concise and readable.
Assuming the suggested change in splash_attention.py is made (where create_splash_attention_fn returns None if disabled), you can refactor this block to be cleaner.
attention_fn = None
if splash_attention_config is not None:
attention_fn = create_splash_attention_fn(
splash_attention_config,
num_heads=num_heads,
head_dim=hidden_size // num_heads,
)
self.attn = nnx.MultiHeadAttention(
num_heads=num_heads,
in_features=hidden_size,
broadcast_dropout=False,
decode=False,
deterministic=False,
dtype=dtype,
param_dtype=param_dtype,
rngs=rngs,
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), mesh_rules("map_attn_in", "map_attn_out")),
bias_init=nnx.with_partitioning(
nnx.initializers.zeros_init(),
mesh_rules(
"map_attn_out",
),
),
attention_fn=attention_fn,
)
Acknowledged. I am now performing a full code review of the pull request. I will post my findings shortly.
Code Review
This pull request introduces experimental support for Splash Attention, which is a great addition for performance on TPUs. The changes are well-structured, adding a new configuration and integrating it throughout the model hierarchy. I've found a critical issue in the implementation of the create_splash_attention_fn which would cause a new test to fail, along with a related inconsistency in how the attention function is handled in different models. My review includes suggestions to fix these issues to ensure correctness and improve maintainability.
if not _TOKAMAX_AVAILABLE or not config.enabled:
    from flax.nnx.nn.attention import dot_product_attention

    return dot_product_attention
The function should return None when Splash Attention is disabled or unavailable. This allows nnx.MultiHeadAttention to fall back to its default dot_product_attention. The current implementation returns dot_product_attention directly, which is inconsistent with the new test test_create_fn_returns_none_when_disabled and can lead to less clear control flow in the model definitions.
if not _TOKAMAX_AVAILABLE or not config.enabled:
    from flax.nnx.nn.attention import dot_product_attention
    return dot_product_attention

if not _TOKAMAX_AVAILABLE or not config.enabled:
    return None
) -> Callable[..., Float[Array, "batch heads seq head_dim"]]:
    """Create a splash attention function compatible with nnx.MultiHeadAttention.

    Args:
        config (SplashAttentionConfig): Splash attention configuration.
        num_heads (int): Number of attention heads.
        head_dim (int): Dimension of each attention head.

    Returns:
        Callable: An attention function. Returns splash attention if enabled and available,
            otherwise returns the default dot_product_attention.
    """
The return type hint and docstring for create_splash_attention_fn should be updated to reflect that it can return None when splash attention is disabled. This makes the function signature more accurate and consistent with the proposed implementation change.
) -> Callable[..., Float[Array, "batch heads seq head_dim"]]:
    """Create a splash attention function compatible with nnx.MultiHeadAttention.

    Args:
        config (SplashAttentionConfig): Splash attention configuration.
        num_heads (int): Number of attention heads.
        head_dim (int): Dimension of each attention head.

    Returns:
        Callable: An attention function. Returns splash attention if enabled and available,
            otherwise returns the default dot_product_attention.
    """

) -> Callable[..., Float[Array, "batch heads seq head_dim"]] | None:
    """Create a splash attention function compatible with nnx.MultiHeadAttention.

    Args:
        config (SplashAttentionConfig): Splash attention configuration.
        num_heads (int): Number of attention heads.
        head_dim (int): Dimension of each attention head.

    Returns:
        Callable | None: An attention function if splash attention is enabled and available,
            otherwise `None`.
    """
attention_fn = (
    create_splash_attention_fn(
        splash_attention_config,
        num_heads=num_heads,
        head_dim=hidden_size // num_heads,
    )
    if splash_attention_config is not None
    else nnx.dot_product_attention
)
The logic for determining attention_fn can be simplified and made more consistent with the pattern used in src/jimm/common/vit.py. By initializing attention_fn to None and only creating the splash function if splash_attention_config is provided, the code becomes more explicit and readable. nnx.MultiHeadAttention will correctly use its default when attention_fn is None.
This change depends on create_splash_attention_fn returning None when disabled, as suggested in another comment.
attention_fn = None
if splash_attention_config is not None:
attention_fn = create_splash_attention_fn(
splash_attention_config,
num_heads=num_heads,
head_dim=hidden_size // num_heads,
)
/gemini review
Code Review
This pull request introduces experimental support for TPU-optimized Splash Attention using the tokamax library. It adds a new configuration class, integrates Splash Attention into the model codebase, and updates documentation and dependencies accordingly. The changes are well-structured, with new functionality encapsulated in src/jimm/common/splash_attention.py. The integration into existing models is done by plumbing the new configuration through the model hierarchy, which is a clean approach. New tests are added to ensure the changes don't break existing functionality when Splash Attention is disabled. My review focuses on improving code clarity and maintainability in the new splash_attention module and related components. I've suggested simplifying some logic and improving naming for better readability. Overall, this is a great addition to the library.
if not _TOKAMAX_AVAILABLE or not config.enabled:
    from flax.nnx.nn.attention import dot_product_attention

    return dot_product_attention
The import of dot_product_attention is performed inside this function. It's generally better practice to place all imports at the top of the file for improved readability and to avoid repeated import operations. Since flax is a core dependency of the project, you can add from flax.nnx.nn.attention import dot_product_attention at the top of the file and simplify this block.
if not _TOKAMAX_AVAILABLE or not config.enabled:
    from flax.nnx.nn.attention import dot_product_attention
    return dot_product_attention

if not _TOKAMAX_AVAILABLE or not config.enabled:
    return dot_product_attention
src/jimm/common/splash_attention.py
| """ | ||
| seq_len = query.shape[2] | ||
| kernel = _create_splash_kernel(seq_len, num_heads, head_dim, config) | ||
| return jax.vmap(lambda q, k, v: kernel(q, k, v))(query, key, value) |
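The snippet above maps a per-example attention kernel over the batch axis with `jax.vmap`. A self-contained sketch of that pattern, with a plain dot-product kernel standing in for the splash kernel (the kernel body and shapes here are illustrative, not the PR's implementation):

```python
import jax
import jax.numpy as jnp


# Hypothetical per-example kernel: operates on a single
# (heads, seq, head_dim) example with no batch axis.
def kernel(q, k, v):
    scores = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum("hqk,hkd->hqd", jax.nn.softmax(scores, axis=-1), v)


batch, heads, seq, head_dim = 2, 4, 8, 16
q = k = v = jnp.ones((batch, heads, seq, head_dim))

# jax.vmap lifts the kernel over the leading batch axis, mirroring
# jax.vmap(lambda q, k, v: kernel(q, k, v))(query, key, value) above.
out = jax.vmap(kernel)(q, k, v)
print(out.shape)  # (2, 4, 8, 16)
```

Note that `jax.vmap(kernel)` and `jax.vmap(lambda q, k, v: kernel(q, k, v))` are equivalent here; the lambda wrapper adds nothing when the kernel already takes exactly those positional arguments.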
src/jimm/common/vit.py
attention_fn = None
if splash_attention_config is not None:
    attention_fn = create_splash_attention_fn(
        splash_attention_config,
        num_heads=num_heads,
        head_dim=hidden_size // num_heads,
    )

attn_kwargs: dict = {
    "num_heads": num_heads,
    "in_features": hidden_size,
    "broadcast_dropout": False,
    "decode": False,
    "deterministic": False,
    "dtype": dtype,
    "param_dtype": param_dtype,
    "rngs": rngs,
    "kernel_init": nnx.with_partitioning(nnx.initializers.xavier_uniform(), mesh_rules("map_attn_in", "map_attn_out")),
    "bias_init": nnx.with_partitioning(
        nnx.initializers.zeros_init(),
        mesh_rules(
            "map_attn_out",
        ),
    ),
}
if attention_fn is not None:
    attn_kwargs["attention_fn"] = attention_fn
self.attn = nnx.MultiHeadAttention(**attn_kwargs)
The logic for creating and assigning the attention_fn can be simplified. You can directly create and add the attention_fn to the attn_kwargs dictionary within a single conditional block, which removes the need for the intermediate attention_fn variable and a second if statement. This makes the code more direct and readable.
attn_kwargs: dict = {
"num_heads": num_heads,
"in_features": hidden_size,
"broadcast_dropout": False,
"decode": False,
"deterministic": False,
"dtype": dtype,
"param_dtype": param_dtype,
"rngs": rngs,
"kernel_init": nnx.with_partitioning(nnx.initializers.xavier_uniform(), mesh_rules("map_attn_in", "map_attn_out")),
"bias_init": nnx.with_partitioning(
nnx.initializers.zeros_init(),
mesh_rules(
"map_attn_out",
),
),
}
if splash_attention_config is not None:
attn_kwargs["attention_fn"] = create_splash_attention_fn(
splash_attention_config,
num_heads=num_heads,
head_dim=hidden_size // num_heads,
)
self.attn = nnx.MultiHeadAttention(**attn_kwargs)
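The suggestion above builds a kwargs dict and only inserts the optional key inside the conditional before unpacking it into the constructor. A minimal framework-free sketch of that pattern (the class and all names are illustrative stand-ins for `nnx.MultiHeadAttention`):

```python
class MultiHead:
    # Stand-in for a constructor where attention_fn=None means
    # "use the default implementation".
    def __init__(self, num_heads, attention_fn=None):
        self.num_heads = num_heads
        self.attention_fn = attention_fn or (lambda q, k, v: q)


def build(num_heads, use_custom):
    kwargs = {"num_heads": num_heads}
    if use_custom:
        # Only add the key when a custom kernel is requested;
        # otherwise the constructor's default applies.
        kwargs["attention_fn"] = lambda q, k, v: v
    return MultiHead(**kwargs)


print(build(8, use_custom=False).num_heads)  # 8
```

Passing the value directly (possibly `None`) is shorter, as the earlier review comments note, but the conditional-key form is useful when the constructor distinguishes "argument omitted" from "argument is None".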
tests/test_splash_attention.py
    assert config.block_kv == 256


def test_create_fn_returns_none_when_disabled() -> None:
The name of this test function, test_create_fn_returns_none_when_disabled, is a bit misleading as the function doesn't return None but rather the default dot_product_attention function. Renaming it would improve clarity.
def test_create_fn_returns_none_when_disabled() -> None:
def test_create_fn_returns_default_when_disabled() -> None:
…d the test function. Also simplified the splash attention fn
This reverts commit e496f89.
This pull request introduces experimental support for TPU-optimized Splash Attention using the tokamax library. It adds a new configuration class, integrates Splash Attention into the model codebase, and updates documentation and dependencies accordingly. The changes allow users to enable Splash Attention for improved performance on supported hardware (e.g., TPU), and provide a flexible API for configuration.

Splash Attention Integration:
- Added the SplashAttentionConfig dataclass and kernel creation utilities in src/jimm/common/splash_attention.py to encapsulate configuration and provide a factory for Splash Attention functions.
- Plumbed the splash_attention_config parameter through the model hierarchy, wiring the custom attention function into nnx.MultiHeadAttention when enabled.

Model and API Updates:
- Exposed SplashAttentionConfig in the package API (src/jimm/__init__.py) for easy user access.
- Updated from_pretrained methods to accept and propagate splash_attention_config.

Documentation and Dependency Updates:
- Updated README.md to document Splash Attention usage, configuration, and hardware caveats.
- Added tokamax>=0.0.9 as a dependency and updated the minimum Python version to 3.11 in pyproject.toml.

These changes enable experimental, hardware-optimized attention for users on supported devices, while maintaining backward compatibility and clear configuration options.